"""
compute_Amu.py
~~~~~~~~~~~~~~

Compute the gauge potential ``A_μ(i)`` for each link ``i`` in a discrete
lattice gauge theory.  The potential is constructed from three pieces:

1.  A per‑link flip count ``n_i`` produced by the companion flip count
    simulator (Volume 1).  Flip counts quantify the number of times a
    context‑dependent operator changes at each link.
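2.  A kernel value ``ρ_i`` read from ``kernel.npy`` in ``data_dir``, as
    provided by the Volume 3 diagnostics package or, if no diagnostic
    kernel is available, as a constant array of ones.  This module does not
    generate that fallback itself: when ``kernel.npy`` is missing the U(1)
    output is skipped.  For non‑Abelian gauge groups this scalar kernel is
    promoted to a matrix by tensoring with the appropriate generator of the
    Lie algebra (e.g. σ_z/2 for SU(2)); the promoted kernels are read from
    ``kernel_SU2.npy`` and ``kernel_SU3.npy``.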
3.  A logistic fractal dimension ``D(n_i)`` and linear pivot weight ``g(D)``.

The logistic dimension is defined by

    D(n) = 1 + 2 / (1 + exp(k * (n - n0)))

where ``k`` and ``n0`` are parameters.  The pivot weight is a linear
function of ``D``: ``g(D) = a * D + b``.  The gauge potential for the
U(1) group is then

    A_i = g * g(D(n_i)) * ρ_i

where ``g`` is the gauge coupling constant (config key ``g``).  For SU(2)
and SU(3) each link's matrix‑valued kernel is multiplied by the same scalar
``g * g(D(n_i))``.

The results are saved into ``data_dir`` as ``A_U1.npy``, ``A_SU2.npy`` and
``A_SU3.npy`` if the corresponding kernels exist.  Missing kernels are
ignored, allowing the pipeline to run with just U(1) data.
"""

from __future__ import annotations

import os
import yaml
import numpy as np
from typing import Optional


def logistic_dimension(n: np.ndarray, k: float, n0: float) -> np.ndarray:
    """Compute the logistic fractal dimension for an array of flip counts.

    Evaluates ``D(n) = 1 + 2 / (1 + exp(k * (n - n0)))``.  ``D`` is 2 at
    ``n == n0`` and tends to 3 as ``k * (n - n0) -> -inf`` and to 1 as
    ``k * (n - n0) -> +inf``.

    The value is computed through the exact identity
    ``1 / (1 + exp(x)) == (1 - tanh(x / 2)) / 2``, i.e.
    ``D(n) = 2 - tanh(k * (n - n0) / 2)``.  The direct form calls ``exp`` on
    ``k * (n - n0)``, which overflows for large slopes or for flip counts far
    above ``n0``.  That overflow emits RuntimeWarnings, or raises under
    ``np.errstate(over='raise')``.  ``tanh`` saturates instead, so the
    result is the same without any overflow.

    Parameters
    ----------
    n : numpy.ndarray
        Array of flip counts per link (anything ``np.asarray`` accepts).
    k : float
        Logistic slope parameter.  If zero, the logistic reduces to a
        constant 2.
    n0 : float
        Logistic midpoint parameter.

    Returns
    -------
    numpy.ndarray
        Array of logistic dimensions of the same shape as ``n``.
    """
    n = np.asarray(n, dtype=float)
    if k == 0.0:
        # The logistic is flat when k == 0 (D == 2 everywhere).  Short-circuit
        # so that non-finite flip counts (where 0 * inf is NaN) still give 2.
        return np.full_like(n, 2.0)
    return 2.0 - np.tanh(0.5 * k * (n - n0))


def pivot_weight(D: np.ndarray, a: float, b: float) -> np.ndarray:
    """Map fractal dimensions to pivot weights with the affine rule g(D) = a * D + b.

    Parameters
    ----------
    D : numpy.ndarray
        Fractal dimension of each link.
    a : float
        Slope of the affine map.
    b : float
        Offset (intercept) of the affine map.

    Returns
    -------
    numpy.ndarray
        ``a * D + b`` evaluated element-wise over ``D``.
    """
    scaled = a * D
    return scaled + b


def _resolve_path(base_dir: str, path_cfg: Optional[str], default_name: str) -> str:
    """Resolve a path relative to a base directory.

    If ``path_cfg`` is absolute it is returned as is.  If it is relative it
    is joined with ``base_dir``.  If it is None the ``default_name`` is
    appended to ``base_dir``.
    """
    if path_cfg:
        return path_cfg if os.path.isabs(path_cfg) else os.path.join(base_dir, path_cfg)
    return os.path.join(base_dir, default_name)


def _fit_to_length(values: np.ndarray, length: int, label: str) -> np.ndarray:
    """Repeat the 1-D ``values`` cyclically, or truncate it, to exactly ``length`` entries."""
    if len(values) == 0 and length > 0:
        # Tiling an empty array is meaningless (previously a ZeroDivisionError).
        raise ValueError(f"{label} array is empty; cannot broadcast it to {length} links")
    # np.resize repeats its input cyclically when growing and keeps the
    # leading entries when shrinking.  This lets e.g. a 4x4 flip-count
    # dataset be reused on an 8x8 lattice.
    return np.resize(values, length)


def _load_optional(path: str) -> Optional[np.ndarray]:
    """Return the array stored at ``path``, or ``None`` if the file does not exist."""
    if not os.path.exists(path):
        return None
    # NOTE(review): allow_pickle=True lets a crafted .npy file execute code;
    # only point data_dir at trusted pipeline outputs.
    return np.load(path, allow_pickle=True)


def _scale_matrix_kernel(kernel: np.ndarray, weights: np.ndarray, num_links: int, label: str) -> np.ndarray:
    """Multiply each link's block of a matrix-valued kernel by that link's scalar weight."""
    # Keep complex kernels complex.  Non-diagonal Lie-algebra generators
    # (e.g. sigma_y, several Gell-Mann matrices) are complex, and casting them
    # to float would silently discard the imaginary part.
    dtype = complex if np.iscomplexobj(kernel) else float
    kernel = np.asarray(kernel, dtype=dtype)
    if kernel.ndim == 0:
        raise ValueError(f"{label} must have a leading per-link dimension")
    if kernel.shape[0] != num_links:
        raise ValueError(f"{label} first dimension {kernel.shape[0]} does not match number of links {num_links}")
    # Give the weights trailing singleton axes so they broadcast over the
    # per-link block, whatever its rank (e.g. (N,) -> (N, 1, 1) for (N, d, d)).
    per_link = weights.reshape(weights.shape[:1] + (1,) * (kernel.ndim - 1))
    return per_link * kernel


def main(config_path: str = 'config.yaml') -> None:
    """Entry point for computing and writing the gauge potential arrays.

    Reads the YAML configuration at ``config_path``.  A relative path is
    resolved against the current working directory.  The configuration
    locates the lattice, flip counts and kernels.  For each kernel file
    present in ``data_dir`` (``kernel.npy``, ``kernel_SU2.npy``,
    ``kernel_SU3.npy``), the corresponding ``A_U1.npy``, ``A_SU2.npy`` or
    ``A_SU3.npy`` is written into ``data_dir``.  Missing kernels simply
    result in the absence of the corresponding output files.

    Configuration keys (all optional):

    * ``data_dir`` (default ``'data'``): relative to the config file's
      directory; created if missing.
    * ``lattice_file`` (default ``'lattice.npy'``): relative to
      ``data_dir``.  Its length defines the number of links.
    * ``flip_counts_path``: relative to the config file's directory.  If
      unset, all flip counts are zero.  The array is flattened, then tiled
      cyclically or truncated to the number of links.
    * ``pivot.a`` and ``pivot.b`` (default 0).
    * ``pivot.logistic_k``: if absent, the first entry of
      ``pivot.logistic_k_values`` is used, then 0.
    * ``pivot.logistic_n0`` (default 0).
    * ``g``: gauge coupling (default 1).

    Raises
    ------
    FileNotFoundError
        If the configuration file, or a lattice / flip-count file it names,
        does not exist.
    ValueError
        If the configuration is not a mapping, if the flip counts or the
        U(1) kernel are empty, or if a matrix-valued kernel's leading
        dimension does not match the number of links.
    """
    # Load configuration
    cfg_file = config_path if os.path.isabs(config_path) else os.path.abspath(config_path)
    if not os.path.exists(cfg_file):
        raise FileNotFoundError(f"Cannot find configuration file {config_path}")
    with open(cfg_file, encoding='utf-8') as f:
        # An empty YAML document parses to None; treat it as "all defaults".
        cfg = yaml.safe_load(f) or {}
    if not isinstance(cfg, dict):
        raise ValueError(f"Configuration file {config_path} must contain a mapping at the top level")

    # Relative paths in the config are anchored at the config file's directory.
    base_dir = os.path.dirname(cfg_file)
    data_dir_cfg = cfg.get('data_dir')
    if data_dir_cfg is None:
        data_dir_cfg = 'data'
    data_dir = data_dir_cfg if os.path.isabs(data_dir_cfg) else os.path.join(base_dir, data_dir_cfg)
    os.makedirs(data_dir, exist_ok=True)

    # The lattice is only used to count links.
    lattice_path = _resolve_path(data_dir, cfg.get('lattice_file', 'lattice.npy'), 'lattice.npy')
    # NOTE(review): allow_pickle=True -- only load trusted lattice files.
    lattice = np.load(lattice_path, allow_pickle=True)
    num_links = len(lattice)

    # Flip counts: loaded if configured, otherwise all zero.
    flip_counts_path_cfg = cfg.get('flip_counts_path')
    if flip_counts_path_cfg:
        flip_counts_path = (flip_counts_path_cfg if os.path.isabs(flip_counts_path_cfg)
                            else os.path.normpath(os.path.join(base_dir, flip_counts_path_cfg)))
        flip_counts = np.load(flip_counts_path, allow_pickle=True)
    else:
        flip_counts = np.zeros(num_links, dtype=float)
    flip_counts = _fit_to_length(np.asarray(flip_counts, dtype=float).reshape(-1), num_links, 'Flip counts')

    # Pivot parameters.  `pivot:` with no value parses to None, hence `or {}`.
    pivot_cfg = cfg.get('pivot') or {}
    a = float(pivot_cfg.get('a', 0.0))
    # When sweeping parameters the field 'b' will be set by the sweep script; fall back to 0
    b = float(pivot_cfg.get('b', 0.0))
    # Look up the sweep list only when 'logistic_k' is absent.  Evaluating it
    # eagerly used to crash on an empty list even with 'logistic_k' set.
    if 'logistic_k' in pivot_cfg:
        k = float(pivot_cfg['logistic_k'])
    else:
        k = float((pivot_cfg.get('logistic_k_values') or [0.0])[0])
    n0 = float(pivot_cfg.get('logistic_n0', 0.0))
    g_coupling = float(cfg.get('g', 1.0))

    # Per-link scalar prefactor g * g(D(n_i)), shared by every gauge group.
    D = logistic_dimension(flip_counts, k, n0)
    gD = pivot_weight(D, a, b)
    weights = g_coupling * gD

    # Load every available kernel before writing any output.
    K_u1 = _load_optional(os.path.join(data_dir, 'kernel.npy'))
    K_su2 = _load_optional(os.path.join(data_dir, 'kernel_SU2.npy'))
    K_su3 = _load_optional(os.path.join(data_dir, 'kernel_SU3.npy'))

    # U(1): a scalar per link, tiled/truncated like the flip counts.
    if K_u1 is not None:
        K_u1 = _fit_to_length(np.asarray(K_u1, dtype=float).reshape(-1), num_links, 'U(1) kernel')
        np.save(os.path.join(data_dir, 'A_U1.npy'), weights * K_u1)

    # SU(2) / SU(3): one matrix per link; the leading dimension must match exactly.
    for K, label, out_name in ((K_su2, 'Kernel_SU2', 'A_SU2.npy'),
                               (K_su3, 'Kernel_SU3', 'A_SU3.npy')):
        if K is not None:
            np.save(os.path.join(data_dir, out_name), _scale_matrix_kernel(K, weights, num_links, label))


if __name__ == '__main__':
    import sys

    # Optional first command-line argument: path to the YAML configuration.
    cli_args = sys.argv[1:]
    main(cli_args[0] if cli_args else 'config.yaml')